{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Algorithm DDPG\n",
    "****\n",
    "Randomly initialize critic network $Q(s,a|\\theta^Q)$ and actor $\\mu(s|\\theta^\\mu)$ with weights $\\theta^Q$ and $\\theta^\\mu$<br>\n",
    "Initialize target networks $Q'$ and $\\mu'$ with weights $\\theta^{Q'}\\leftarrow \\theta^Q, \\theta^{\\mu'} \\leftarrow \\theta^\\mu$<br>\n",
    "Initialize replay buffer $R$<br>\n",
    "**for** episode = 1, M **do**<br>\n",
    "&nbsp;&nbsp;&nbsp;Initialize a random process $N$ for action exploration<br>\n",
    "&nbsp;&nbsp;&nbsp;Receive initial observation state $s_1$<br>\n",
    "&nbsp;&nbsp;&nbsp;**for** t=1, T **do**<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Select action $a_t=\\mu(s_t|\\theta^\\mu)+N_t$ according to the current policy and exploration noise<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Execute action $a_t$, observe reward $r_t$ and new state $s_{t+1}$<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Store transition $(s_t,a_t,r_t,s_{t+1})$ in $R$<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Sample a random minibatch of $N$ transitions $(s_i,a_i,r_i,s_{i+1})$ from $R$<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Set $y_i=r_i+\\gamma Q'(s_{i+1},\\mu'(s_{i+1}|\\theta^{\\mu'})|\\theta^{Q'})$<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Update critic by minimizing the loss: $L=\\frac{1}{N}\\sum_i(y_i-Q(s_i,a_i|\\theta^Q))^2$<br>\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Update the actor policy using the sampled policy gradient:\n",
    "$$\n",
    "\\nabla_{\\theta^\\mu}J\\approx\\frac{1}{N}\\sum_i\\nabla_aQ(s,a|\\theta^Q)|_{s=s_i,a=\\mu(s_i)}\\nabla_{\\theta^\\mu}\\mu(s|\\theta^\\mu)|_{s_i}\n",
    "$$\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;Update the target networks:\n",
    "$$\n",
    "\\theta^{Q'}\\leftarrow \\tau \\theta^Q +(1-\\tau)\\theta^{Q'}\\\\\n",
    "\\theta^{\\mu'}\\leftarrow \\tau \\theta^\\mu+(1-\\tau)\\theta^{\\mu'}\n",
    "$$\n",
    "&nbsp;&nbsp;&nbsp;**end for**<br>\n",
    "**end for**<br>\n",
    "\n",
    "Note: $\\mu$ denotes a deterministic policy<br>\n",
    "<br>\n",
    "### Loss computation in code:<br>\n",
    "* #### Critic Loss :\n",
    "$$\n",
    "\\frac{1}{N} \\sum_i \\Big(\\overbrace{reward+(1.0-done) * gamma * target\\_value\\_net(next\\_state, target\\_policy\\_net(next\\_state))}^{target\\_value}- \\underbrace{value\\_net(state, action)}_{value}\\Big)^2\n",
    "$$\n",
    "* #### Policy Loss :\n",
    "$$\n",
    "-\\frac{1}{N} \\sum_i value\\_net(state, policy\\_net(state))\n",
    "$$"
   ]
  },
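  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two losses above can be sketched with plain numpy on a toy batch (the `target_q` and `q` arrays here are made-up placeholders standing in for the network outputs):\n",
    "```python\n",
    "import numpy as np\n",
    "N, gamma = 4, 0.99\n",
    "reward, done = np.ones(N), np.zeros(N)\n",
    "target_q = np.full(N, 2.0)  # placeholder for target_value_net(next_state, target_policy_net(next_state))\n",
    "q = np.full(N, 1.5)         # placeholder for value_net(state, action)\n",
    "target_value = reward + (1.0 - done) * gamma * target_q\n",
    "critic_loss = np.mean((target_value - q) ** 2)\n",
    "policy_loss = -np.mean(q)   # minimizing -Q is gradient ascent on Q\n",
    "```"
   ]
  },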
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The pseudocode above is summarized in the following table:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center><style type=\"text/css\">\n",
    ".tg  {border-collapse:collapse;border-spacing:0;border-color:#bbb;}\n",
    ".tg td{font-family:Arial, sans-serif;font-size:14px;padding:10px 5px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal;border-top-width:1px;border-bottom-width:1px;border-color:#bbb;color:#594F4F;background-color:#E0FFEB;}\n",
    ".tg th{font-family:Arial, sans-serif;font-size:14px;font-weight:normal;padding:10px 5px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal;border-top-width:1px;border-bottom-width:1px;border-color:#bbb;color:#493F3F;background-color:#9DE0AD;}\n",
    ".tg .tg-koh6{font-size:36px;border-color:inherit;text-align:left;vertical-align:top}\n",
    ".tg .tg-7ab7{font-weight:bold;font-size:36px;border-color:inherit;text-align:center;vertical-align:top}\n",
    ".tg .tg-bwtg{font-size:36px;border-color:inherit;text-align:center;vertical-align:top}\n",
    ".tg .tg-n9dp{background-color:#C2FFD6;font-size:36px;border-color:inherit;text-align:left;vertical-align:top}\n",
    "</style>\n",
    "<table class=\"tg\">\n",
    "  <tr>\n",
    "    <th class=\"tg-7ab7\">Algorithm</th>\n",
    "    <th class=\"tg-7ab7\">Neural Network</th>\n",
    "    <th class=\"tg-7ab7\">input</th>\n",
    "    <th class=\"tg-7ab7\">output</th>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td class=\"tg-bwtg\" rowspan=\"2\">Critic</td>\n",
    "    <td class=\"tg-n9dp\">target critic network<br>θ^Q'</td>\n",
    "    <td class=\"tg-koh6\">next state s_{i+1} &amp;&amp;<br>the action a_{i+1} of s_{i+1} through θ^𝜇'</td>\n",
    "    <td class=\"tg-n9dp\">Q'(s_{i+1},a_{i+1})<br><big><big>for critic loss</big></big></td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td class=\"tg-n9dp\">critic network<br>θ^Q</td>\n",
    "    <td class=\"tg-koh6\">current state s_i &amp;&amp;<br>action in current state a_i</td>\n",
    "    <td class=\"tg-n9dp\">Q(s_i,a_i) <br><big><big>for critic loss and <br>policy loss</big></big></td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td class=\"tg-bwtg\" rowspan=\"2\">Actor</td>\n",
    "    <td class=\"tg-n9dp\">target actor network<br> θ^𝜇'</td>\n",
    "    <td class=\"tg-koh6\">next state s_{i+1}</td>\n",
    "      <td class=\"tg-n9dp\">action a_{i+1} at<br>state s_{i+1}<br><big><big>for critic loss</big></big></td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td class=\"tg-n9dp\">actor network <br>θ^𝜇</td>\n",
    "    <td class=\"tg-koh6\">current state s_i</td>\n",
    "    <td class=\"tg-n9dp\">𝜇(s_i)<br><big><big> for policy loss</big></big></td>\n",
    "  </tr>\n",
    "</table></center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../assets/DDPG_landscape.png\" width=600 />"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Summary of key properties\n",
    "1. A replay buffer, as in DQN\n",
    "2. Target networks, as in DQN, except that they are updated here with soft updates\n",
    "3. Actor_net takes s and s' to select actions a and a', used to update actor_net and Q_net respectively; Q_net takes a and a' to evaluate (s,a) and (s',a'), which in turn update actor_net and Q_net\n",
    "4. Only handles continuous action spaces; a deterministic policy has no meaning for discrete actions\n",
    "5. Normalize the action space and add noise to the action\n",
    "6. Actor-critic in spirit, but the loss computation differs a lot from A2C<br>\n",
    "    a. In A2C, critic loss $\\leftrightarrow$ TD error $\\leftrightarrow$ advantage function,<br>\n",
    "    since the TD error is an unbiased estimate of $\\hat{A}(s,a)=Q(s,a)-V(s)$.<br>\n",
    "    In DDPG the critic update closely resembles Double DQN, except that the $a_{i+1}$ used for the target Q comes from the target actor network $\\theta^{\\mu'}$<br>\n",
    "    b. For the actor loss, DDPG uses $Q(s,a)$ rather than $A(s,a)$, presumably to save computation"
   ]
  },
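  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The soft update in point 2 is just an exponential moving average of the online weights; treating parameters as plain numbers for illustration:\n",
    "```python\n",
    "def soft_update(target_params, params, tau=0.01):\n",
    "    # move each target parameter a small step toward the online parameter\n",
    "    return [tau * p + (1.0 - tau) * tp for tp, p in zip(target_params, params)]\n",
    "\n",
    "soft_update([0.0], [1.0], tau=0.1)  # [0.1]\n",
    "```"
   ]
  },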
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Code Tips:\n",
    "**np.clip $\\leftrightarrow$ torch.clamp <br>**\n",
    "$y_i = \\begin{cases}\n",
    "        \\text{min} & \\text{if } x_i < \\text{min} \\\\\n",
    "        x_i & \\text{if } \\text{min} \\leq x_i \\leq \\text{max} \\\\\n",
    "        \\text{max} & \\text{if } x_i > \\text{max}\n",
    "    \\end{cases}$<br>\n",
    "    <br>\n",
    "**torch.detach() , torch.data:<br>**\n",
    "Both cut a tensor out of the autograd graph, i.e. they yield a result with requires_grad=False<br>\n",
    "detach() is recommended, since autograd can still detect unsafe in-place modifications of the detached tensor<br>\n",
    "In DDPG, target_Q.detach() ensures that only value_net is updated<br>\n",
    "<br>\n",
    "**Without torch.tanh() on the Policy_net output**<br>\n",
    "the action space is no longer constrained to $-1 \\sim +1$, and things break<br>\n",
    "The tanh formula is:<br>\n",
    "$\\text{Tanh}(x) = \\tanh(x) = \\frac{e^x - e^{-x}}{e^x + e^{-x}}$\n",
    "<img src=\"../assets/TanhReal.gif\">"
   ]
  },
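  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick numpy check of the two tips above (torch.clamp behaves the same way element-wise):\n",
    "```python\n",
    "import numpy as np\n",
    "x = np.array([-2.0, 0.5, 3.0])\n",
    "clipped = np.clip(x, -1.0, 1.0)  # elements outside [-1, 1] are pinned to the bounds\n",
    "t = (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))  # the tanh formula above\n",
    "assert np.allclose(t, np.tanh(x))\n",
    "```"
   ]
  },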
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MountainCarContinuous-v0: this environment could probably be solved with HER<br>\n",
    "But the noise I add has a constant magnitude, so it keeps pushing the car either forward or backward the whole time; if I reduce the number of steps, this has no effect<br>\n",
    "Suppose the force points forward (since we ultimately need to reach the hilltop on the far side); then I can keep applying a vector along the direction of that force<br>\n",
    "What is this vector for? It keeps me moving toward the goal even when rewards are very sparse, so I can still sample highly informative transitions<br>\n",
    "The question then is: how do I derive this unit vector, or rather, how do I fit it? With a neural network?<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.distributions import Normal\n",
    "# from torch.utils.tensorboard import SummaryWriter\n",
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns; sns.set(color_codes=True)\n",
    "import pdb\n",
    "%matplotlib inline\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experience replay buffer\n",
    "class ReplayBuffer:\n",
    "    def __init__(self, column, batch_size, buffer_size):\n",
    "        self.current_state = np.zeros((buffer_size, column), dtype=np.float32)\n",
    "        self.next_state = np.zeros((buffer_size, column), dtype=np.float32)\n",
    "        init = lambda buffer_size : np.zeros(buffer_size, dtype=np.float32)\n",
    "        self.action = init(buffer_size)\n",
    "        self.reward = init(buffer_size)\n",
    "        self.done = init(buffer_size)\n",
    "        self.buffer_size, self.batch_size = buffer_size, batch_size\n",
    "        self.size, self.current_index = 0, 0\n",
    "    \n",
    "    def store(self, current_state, action, next_state, reward, done):\n",
    "        self.current_state[self.current_index] = current_state\n",
    "        self.action[self.current_index] = action\n",
    "        self.next_state[self.current_index] = next_state\n",
    "        self.reward[self.current_index] = reward\n",
    "        self.done[self.current_index] = done\n",
    "        self.current_index = (self.current_index + 1) % self.buffer_size\n",
    "        self.size = min(self.buffer_size, (self.size + 1))\n",
    "    \n",
    "    def sample(self):\n",
    "        index = np.random.choice(self.size, self.batch_size, replace=False)\n",
    "        return dict(current_state = torch.tensor(self.current_state[index],dtype=torch.float).to(device),\n",
    "                    action = torch.tensor(self.action[index]).reshape(-1, 1).to(device),\n",
    "                    next_state = torch.tensor(self.next_state[index],dtype=torch.float).to(device),\n",
    "                    reward = torch.tensor(self.reward[index]).reshape(-1, 1).to(device),\n",
    "                    done = torch.tensor(self.done[index]).reshape(-1, 1).to(device))\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A Wrapper preprocesses the input data before feeding it into the env\n",
    "<img src=\"../assets/gym_wrapper.png\" align=center>"
   ]
  },
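  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rescaling done by the wrapper below is simple affine arithmetic; for Pendulum's action bounds of -2 ~ +2:\n",
    "```python\n",
    "low, high = -2.0, 2.0\n",
    "to_env = lambda a: low + (a + 1.0) * 0.5 * (high - low)  # [-1, 1] -> [low, high]\n",
    "to_net = lambda a: 2 * (a - low) / (high - low) - 1      # [low, high] -> [-1, 1]\n",
    "to_env(1.0), to_env(-1.0), to_net(to_env(0.25))  # (2.0, -2.0, 0.25)\n",
    "```"
   ]
  },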
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "code_folding": [],
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "# Normalize the action space: action() rescales actions going into the env, reverse_action() maps them back\n",
    "class NormalizeActions(gym.ActionWrapper):\n",
    "    \n",
    "    def __init__(self, env):\n",
    "        super(NormalizeActions, self).__init__(env)\n",
    "        self.low_bound = self.action_space.low # -2\n",
    "        self.upper_bound = self.action_space.high # +2\n",
    "        \n",
    "    def action(self, action):\n",
    "        # the network's tanh output lies in -1 ~ +1; rescale it to the env's bounds\n",
    "        action = self.low_bound + (action + 1.0) * 0.5 * (self.upper_bound - self.low_bound)\n",
    "        action = np.clip(action, self.low_bound, self.upper_bound)\n",
    "        \n",
    "        return action\n",
    "    \n",
    "    def reverse_action(self, action):\n",
    "        # map an env-space action back into -1 ~ +1\n",
    "        action = 2 * (action - self.low_bound) / (self.upper_bound - self.low_bound) - 1\n",
    "        action = np.clip(action, self.low_bound, self.upper_bound)\n",
    "        \n",
    "        return action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With a deterministic policy, add noise to the action for exploration\n",
    "class OUNoise:\n",
    "    def __init__(self, action_space, mu=0.0, theta=0.15, max_sigma=0.3, min_sigma=0.1, decay_period=100000):\n",
    "        self.mu = mu\n",
    "        self.theta = theta\n",
    "        self.sigma = max_sigma\n",
    "        self.max_sigma = max_sigma\n",
    "        self.min_sigma = min_sigma\n",
    "        self.decay_period = decay_period\n",
    "        self.action_dim = action_space.shape[0]\n",
    "        self.low = action_space.low\n",
    "        self.high = action_space.high\n",
    "        # initialize the noise state\n",
    "        self.reset()\n",
    "    \n",
    "    def reset(self):\n",
    "        self.state = np.ones(self.action_dim) * self.mu\n",
    "    \n",
    "    def evolve_state(self):\n",
    "        # self.state += [0.0] + 0.3 * np.random.randn(1)\n",
    "        self.state = self.state + self.theta * (self.mu - self.state)+\\\n",
    "                        np.random.randn(self.action_dim) *np.sqrt(0.3) * self.sigma\n",
    "        return self.state\n",
    "    \n",
    "    def get_action(self, action, t=0):\n",
    "        # self.sigma = self.max_sigma\n",
    "        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, t / self.decay_period)\n",
    "        return np.clip(action + self.evolve_state(), self.low, self.high)"
   ]
  },
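  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The noise scale sigma in get_action anneals linearly from max_sigma to min_sigma over decay_period steps; the schedule in isolation:\n",
    "```python\n",
    "max_sigma, min_sigma, decay_period = 0.3, 0.1, 100000\n",
    "def sigma(t):\n",
    "    # the min() clamp keeps sigma at min_sigma once t exceeds decay_period\n",
    "    return max_sigma - (max_sigma - min_sigma) * min(1.0, t / decay_period)\n",
    "```"
   ]
  },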
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two variants: encode state and action separately then concatenate, or concatenate the raw inputs directly\n",
    "class ValueNetwork(nn.Module):\n",
    "    def __init__(self, state_dim, action_dim):\n",
    "        super(ValueNetwork, self).__init__()\n",
    "        self.linear_state = nn.Linear(state_dim, 64)\n",
    "        self.linear_action = nn.Linear(action_dim, 64)\n",
    "        self.linear2 = nn.Linear(128, 32)\n",
    "        self.linear3 = nn.Linear(32, 1)\n",
    "    \n",
    "    def forward(self, state, action):\n",
    "        hidden_state = F.relu(self.linear_state(state))\n",
    "        hidden_action = F.relu(self.linear_action(action))\n",
    "        cat_state_action = torch.cat((hidden_action, hidden_state),dim=1)\n",
    "        hidden2 = F.relu(self.linear2(cat_state_action))\n",
    "        Q = self.linear3(hidden2)\n",
    "        return Q\n",
    "    \n",
    "#     def __init__(self, state_dim, action_dim):\n",
    "#         super(ValueNetwork, self).__init__()\n",
    "#         self.linear1 = nn.Linear(state_dim+action_dim, 256)\n",
    "#         self.linear2 = nn.Linear(256, 256)\n",
    "#         self.linear3 = nn.Linear(256, 1)\n",
    "    \n",
    "#     def forward(self, state, action):\n",
    "#         x = torch.cat((state, action),dim=1)\n",
    "#         x = F.relu(self.linear1(x))\n",
    "#         x = F.relu(self.linear2(x))\n",
    "#         Q = self.linear3(x)\n",
    "#         return Q\n",
    "\n",
    "class PolicyNetwork(nn.Module):\n",
    "    def __init__(self, in_dim, out_dim):\n",
    "        super(PolicyNetwork, self).__init__()\n",
    "        self.linear1 = nn.Linear(in_dim, 128)\n",
    "        self.linear2 = nn.Linear(128, 64)\n",
    "        self.linear3 = nn.Linear(64, out_dim)\n",
    "    \n",
    "    def forward(self, state):\n",
    "        x = F.relu(self.linear1(state))\n",
    "        x = F.relu(self.linear2(x))\n",
    "        x = torch.tanh(self.linear3(x))\n",
    "        return x\n",
    "    \n",
    "    def get_action(self, state):\n",
    "        state = torch.tensor(state,dtype=torch.float).unsqueeze(0).to(device)\n",
    "        action = self.forward(state)\n",
    "        return action.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ddpg_train(batch_size, gamma=0.99, soft_tau=1e-2):\n",
    "    samples = replay_buffer.sample()\n",
    "    state, action, next_state = samples['current_state'], samples['action'], samples['next_state']\n",
    "    reward, done = samples['reward'], samples['done']\n",
    "    \n",
    "    target_value = reward + (1.0-done)*gamma*target_value_net(next_state, target_policy_net(next_state))\n",
    "    value = value_net(state, action)\n",
    "    value_loss = ((value - target_value.detach()).pow(2)).mean()\n",
    "    \n",
    "    policy_loss = -value_net(state, policy_net(state)).mean()\n",
    "    \n",
    "    value_optimizer.zero_grad()\n",
    "    value_loss.backward()\n",
    "    value_optimizer.step()\n",
    "    \n",
    "    policy_optimizer.zero_grad()\n",
    "    policy_loss.backward()\n",
    "    policy_optimizer.step()\n",
    "    \n",
    "    for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):\n",
    "        target_param.data.copy_(target_param.data*(1.0-soft_tau) + param.data*soft_tau)\n",
    "    for target_param, param in zip(target_policy_net.parameters(), policy_net.parameters()):\n",
    "        target_param.data.copy_(target_param.data*(1.0-soft_tau) + param.data*soft_tau)\n",
    "    \n",
    "    return value_loss.item(), policy_loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = NormalizeActions(gym.make(\"Pendulum-v0\"))\n",
    "# env = NormalizeActions(gym.make(\"MountainCarContinuous-v0\"))\n",
    "ou_noise = OUNoise(env.action_space)\n",
    "\n",
    "in_dim = env.observation_space.shape[0]\n",
    "out_dim = env.action_space.shape[0]\n",
    "\n",
    "value_net = ValueNetwork(in_dim, out_dim).to(device)\n",
    "policy_net = PolicyNetwork(in_dim, out_dim).to(device)\n",
    "\n",
    "target_value_net = ValueNetwork(in_dim, out_dim).to(device)\n",
    "target_policy_net = PolicyNetwork(in_dim, out_dim).to(device)\n",
    "target_value_net.load_state_dict(value_net.state_dict())\n",
    "target_policy_net.load_state_dict(policy_net.state_dict())\n",
    "\n",
    "value_optimizer = optim.Adam(value_net.parameters())\n",
    "policy_optimizer = optim.Adam(policy_net.parameters(), lr=1e-4)\n",
    "\n",
    "buffer_size = 1000000\n",
    "batch_size = 128\n",
    "replay_buffer = ReplayBuffer(in_dim, batch_size, buffer_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "def smooth_plot(factor, item, plot_decay):\n",
    "    item_x = np.arange(len(item))\n",
    "    item_smooth = [np.mean(item[i:i+factor]) if i > factor else np.mean(item[0:i+1])\n",
    "                  for i in range(len(item))]\n",
    "    for i in range(len(item)// plot_decay):\n",
    "        item_x = item_x[::2]\n",
    "        item_smooth = item_smooth[::2]\n",
    "    return item_x, item_smooth\n",
    "    \n",
    "def plot(episode, rewards, value_losses, policy_losses):\n",
    "    clear_output(True)\n",
    "    rewards_x, rewards_smooth = smooth_plot(10, rewards, 500)\n",
    "    value_losses_x, value_losses_smooth = smooth_plot(10, value_losses, 30000)\n",
    "    policy_losses_x, policy_losses_smooth = smooth_plot(10, policy_losses, 30000)\n",
    "    \n",
    "    plt.figure(figsize=(18, 12))\n",
    "    plt.subplot(311)\n",
    "    plt.title('episode %s. reward: %s'%(episode, rewards_smooth[-1]))\n",
    "    plt.plot(rewards, label=\"Rewards\", color='lightsteelblue', linewidth='1')\n",
    "    plt.plot(rewards_x, rewards_smooth, label='Smoothed_Rewards', color='darkorange', linewidth='3')\n",
    "    plt.legend(loc='best')\n",
    "    \n",
    "    plt.subplot(312)\n",
    "    plt.title('Value_Losses')\n",
    "    plt.plot(value_losses,label=\"Value_Losses\",color='lightsteelblue',linewidth='1')\n",
    "    plt.plot(value_losses_x, value_losses_smooth, \n",
    "             label=\"Smoothed_Value_Losses\",color='darkorange',linewidth='3')\n",
    "    plt.legend(loc='best')\n",
    "    \n",
    "    plt.subplot(313)\n",
    "    plt.title('Policy_Losses')\n",
    "    plt.plot(policy_losses,label=\"Policy_Losses\",color='lightsteelblue',linewidth='1')\n",
    "    plt.plot(policy_losses_x, policy_losses_smooth, \n",
    "             label=\"Smoothed_Policy_Losses\",color='darkorange',linewidth='3')\n",
    "    plt.legend(loc='best')\n",
    "    \n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "code_folding": [],
    "scrolled": false
   },
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-4ba0dacdc993>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     32\u001b[0m     \u001b[0mall_rewards\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepisode_reward\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m     \u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepisode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mall_rewards\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue_losses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_losses\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-8-c53b74c2eb9e>\u001b[0m in \u001b[0;36mplot\u001b[0;34m(episode, rewards, value_losses, policy_losses)\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0mrewards_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrewards_smooth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmooth_plot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrewards\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m500\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0mvalue_losses_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue_losses_smooth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmooth_plot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue_losses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m30000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m     \u001b[0mpolicy_losses_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_losses_smooth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msmooth_plot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_losses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m30000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m     \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfigsize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m18\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m12\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-8-c53b74c2eb9e>\u001b[0m in \u001b[0;36msmooth_plot\u001b[0;34m(factor, item, plot_decay)\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0mitem_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     item_smooth = [np.mean(item[i:i+factor]) if i > factor else np.mean(item[0:i+1])\n\u001b[0;32m----> 4\u001b[0;31m                   for i in range(len(item))]\n\u001b[0m\u001b[1;32m      5\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m//\u001b[0m \u001b[0mplot_decay\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m         \u001b[0mitem_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mitem_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-8-c53b74c2eb9e>\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0mitem_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     item_smooth = [np.mean(item[i:i+factor]) if i > factor else np.mean(item[0:i+1])\n\u001b[0;32m----> 4\u001b[0;31m                   for i in range(len(item))]\n\u001b[0m\u001b[1;32m      5\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m//\u001b[0m \u001b[0mplot_decay\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m         \u001b[0mitem_x\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mitem_x\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mmean\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.6/site-packages/numpy/core/fromnumeric.py\u001b[0m in \u001b[0;36mmean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m   3255\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3256\u001b[0m     return _methods._mean(a, axis=axis, dtype=dtype,\n\u001b[0;32m-> 3257\u001b[0;31m                           out=out, **kwargs)\n\u001b[0m\u001b[1;32m   3258\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3259\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.6/site-packages/numpy/core/_methods.py\u001b[0m in \u001b[0;36m_mean\u001b[0;34m(a, axis, dtype, out, keepdims)\u001b[0m\n\u001b[1;32m    132\u001b[0m             um.clip, a, min, max, out=out, casting=casting, **kwargs)\n\u001b[1;32m    133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    135\u001b[0m     \u001b[0marr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0masanyarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "episodes = 3000\n",
    "steps = 1000\n",
    "value_losses = []\n",
    "policy_losses = []\n",
    "all_rewards = []\n",
    "for episode in range(episodes):\n",
    "    state = env.reset()\n",
    "    ou_noise.reset()\n",
    "    episode_reward = 0\n",
    "    for step in range(steps):\n",
    "        action = policy_net.get_action(state)\n",
    "        action = ou_noise.get_action(action, step)\n",
    "        next_state, reward, done, _ = env.step(action.flatten())\n",
    "        \n",
    "        replay_buffer.store(state, action, next_state.flatten(), reward, done)\n",
    "        if len(replay_buffer) > batch_size :\n",
    "            value_loss, policy_loss = ddpg_train(batch_size)\n",
    "            value_losses.append(value_loss)\n",
    "            policy_losses.append(policy_loss)\n",
    "        \n",
    "        state = next_state\n",
    "        episode_reward += reward\n",
    "        \n",
    "        if done:\n",
    "            break\n",
    "    \n",
    "#     soft_loss = lambda loss : np.mean(loss[-100:]) if len(loss)>=100 else np.mean(loss)\n",
    "#     writer.add_scalar(\"Reward\", episode_reward, episode)\n",
    "#     writer.add_scalar(\"Value_Losses\", soft_loss(value_losses), episode)\n",
    "#     writer.add_scalar(\"Policy_Losses\", soft_loss(policy_losses), episode)\n",
    "    \n",
    "    all_rewards.append(episode_reward)\n",
    "    \n",
    "    plot(episode, all_rewards, value_losses, policy_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "# torch.save(policy_net.state_dict(), \"./model/DDPG.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "code_folding": []
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD/CAYAAAAT33hZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAL4UlEQVR4nO3dX4icd7nA8e8mgZpWEJu2apqkG9R9CiGK2xTioSre1JumWIUeAjW9STFaencupBcqQiHYXKnrSWgQQgOBelM1N95JDQfFk9NAe/MYZdOkaTn5V4TS014kcy7mXd3sszsz+2d23k2+Hyiz8/vNZH/z0vnmfeedmYx1Oh0kabZ1o16ApPYxDJIKwyCpMAySCsMgqTAMkooNw/zDI2ICOAZsAq4C+zLz7DB/p6TlG/Yew2FgKjMngCngyJB/n6QVMLQwRMR9wCRwohk6AUxGxL3D+p2SVsYwDyW2Ahcz8zpAZl6PiHea8ct97nsH8DDwLnB9iGuUblfrgc8AfwE+mjs51NcYluFh4I+jXoR0G/gKcGru4DBfY7gA3B8R6wGay83NeD/vDnFdkv5l3ufa0MKQmZeAM8DeZmgv8Hpm9juMAA8fpNUy73Nt2IcSB4BjEfFD4D1g35B/n6QVMNbSj12PA9OjXoR0G9gOnJs76DsfJRWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFRv63SAiDgHfBsaBnZn5ZjM+ARwDNgFXgX2ZebbfnKT2G2SP4VXgq8Bbc8YPA1OZOQFMAUcGnJPUcn3DkJmnMvPC7LGIuA+YBE40QyeAyYi4t9fcyi1b0jAt9TWGrcDFzLwO0Fy+04z3mpO0Bvjio6RiqWG4ANwfEesBmsvNzXivOUlrwJLCkJmXgDPA3mZoL/B6Zl7uNbfcxUpaHWOdTqfnDSLiZ8C3gE8DV4CrmbkjIh6ke0ryk8B7dE9JZnOfBecGNA5ML+6hSFqC7cC5uYN9wzAi4xgGaTXMGwZffJRUGAZJhWGQVBgGSYVh0Io5vWcPp/fsGfUytAI8K6FlWygGD/3ud6u8Ei2BZyW08nrtIbj3sHYZBkmFYdCSDbJH4F7D2mQYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGDRU/tsSa5NhkFQYBkmFYZBUGAZJxYZ+N4iITcDLwGeBj4C/Ad/NzMsRsRs4Amyk+w9jPpWZl5r7LTgnqd0G2WPoAD/NzMjMLwB/Bw5GxBhwHHg2MyeA14CDAL3mJLVf3zBk5rXM/MOsoT8BDwC7gA8z81Qzfhh4svm515xuAf6blLe2Rb3GEBHrgO8BvwW2AW/NzGXmFWBdRNzdZ05Sy/V9jWGOnwPvA78Anlj55Wit8I1Lt7aBwxARh4DPA3sy80ZEnKd7SDEzfw/QycxrveZWbukapUEPJQzI2jTQoUREvAA8BHwzMz9qhk8DGyPikeb6AeCVAeYktdwgpyt3AM8DfwX+KyIApjPziYj4DnAkIj5Gc0oSoNmjmHdOUvuNdTqdUa9hPuPA9KgXoYV5KHHL2E73L+6b+M5HSYVhkFQYBkmFYZBUGAZJhWGQVBgGSYVh0IrZdfIkz/35z6NehlbAYj9EJS3ovx977J8/7zp58qbrWlvcY9BQzERhbGxsxCvRUhgGDVWn0zEOa5Bh0NC19PM46sEwaFW417C2GAatCvca1hbPSmhgY2Nj8z7Bd508edP1hc5GHD16lP379w9lbVpZ7jFo1TzzzDOjXoIGZBg0kJ07dwKwY8eOEa9Eq8FDCQ3k
jTfe4IMPPuDOO+9c8p/x0ksvreCKNEx+tZuWrN/Xu/m1bmuCX+2m5bt48eI/f+71xJ875+nKtcUwaFG2bNly0/X54jDfWEv3TLUADyW0aAudtlyp22tVeSihlbGYzz8YhbXJMGhJZp7sCwViZtworE2GQcsy88R//PHHGRsb49FHH71pXGuTrzFItzdfY5A0GMMgqTAMkgrDIKkwDJIKwyCpMAySCsMgqTAMkoqBvsEpIl6l+w6pG8D7wHOZeSYiJoBjwCbgKrAvM88291lwTlK7DbrH8HRmfjEzvwQcAn7VjB8GpjJzApgCjsy6T685SS02UBgy8x+zrn4CuBER9wGTwIlm/AQwGRH39ppbmWVLGqaBX2OIiKMRcR54AXga2ApczMzrAM3lO814rzlJLTdwGDJzf2ZuA54HXhzekiSN2qLPSmTmy8DXgbeB+yNiPUBzuRm40Py30Jyklusbhoj4eERsnXV9D3ANuAScAfY2U3uB1zPzcmYuOLeSi5c0HIOcrrwL+HVE3AVcpxuFPZnZiYgDwLGI+CHwHrBv1v16zUlqMb/BSbq9+Q1OkgZjGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUmEYJBWGQVJhGCQVhkFSYRgkFRsWc+OI+BHwY2BnZr4ZEbuBI8BG4BzwVGZeam674Jykdht4jyEiJoHdwPnm+hhwHHg2MyeA14CD/eYktd9AYYiIO4Ap4PtApxneBXyYmaea64eBJweYk9Ryg+4x/AQ4npnTs8a2AW/NXMnMK8C6iLi7z5yklusbhoj4MvAw8MvhL0dSGwyyx/A14EFgOiLOAVuA3wOfAx6YuVFE3AN0MvMa3dchFpqT1HJ9w5CZBzNzc2aOZ+Y48DbwDeBFYGNEPNLc9ADwSvPz6R5zklpuye9jyMwbwHeA/4yIs3T3LH7Qb05S+411Op3+t1p948B0vxtJWrbtdN9ndBPf+SipMAySCsMgqTAMkgrDIKkwDJIKwyCpMAySCsMgqTAMkgrDIKkwDJIKwyCpMAySCsMgqTAMkgrDIKkwDJIKwyCpMAySCsMgqTAMkgrDIKkwDJIKwyCpMAySCsMgqTAMkgrDIKkwDJIKwyCpMAySCsMgqTAMkgrDIKkwDJIKwyCpaGsY1o96AdJtYt7nWlvD8JlRL0C6Tcz7XBvrdDqrvZBB3AE8DLwLXB/xWqRb0Xq6UfgL8NHcybaGQdIItfVQQtIIGQZJhWGQVBgGSYVhkFQYBkmFYZBUbBj1AuaKiAngGLAJuArsy8yzI1zPIeDbwDiwMzPf7LfOUT2GiNgEvAx8lu6bVv4GfDczL0fEbuAIsBE4BzyVmZea+y04twprfhXYDtwA3geey8wzbdy+s9b8I+DHNP8/tHXbLkcb9xgOA1OZOQFM0d2oo/Qq8FXgrTnjvdY5qsfQAX6amZGZXwD+DhyMiDHgOPBss6bXgIMAveZWydOZ+cXM/BJwCPhVM97G7UtETAK7gfPN9TZv2yVrVRgi4j5gEjjRDJ0AJiPi3lGtKTNPZeaF2WO91jnKx5CZ1zLzD7OG/gQ8AOwCPszMU834YeDJ5udec0OXmf+YdfUTwI22bt+IuINuiL5PN8LQ4m27HK0KA7AVuJiZ1wGay3ea8Tbptc5WPIaIWAd8D/gtsI1ZezyZeQVYFxF395lbrbUejYjzwAvA07R3+/4EOJ6Z07PGWr1tl6ptYdDK+TndY/ZfjHoh/WTm/szcBjwPvDjq9cwnIr5M94N9vxz1WlZD28JwAbg/ItYDNJebm/E26bXOkT+G5gXTzwP/npk36B4PPzBr/h6gk5nX+sytqsx8Gfg68Dbt275fAx4EpiPiHLAF+D3wOdbAtl2sVoWhebX2DLC3GdoLvJ6Zl0e3qqrXOkf9GCLiBeAh
4JuZOfNx2tPAxoh4pLl+AHhlgLlhr/XjEbF11vU9wDWgdds3Mw9m5ubMHM/Mcbrx+gbdPZzWbdvlat3HriPiQbqnoj4JvEf3VFSOcD0/A74FfBq4AlzNzB291jmqxxARO4A3gb8C/9cMT2fmExHxb3Rfvf8Y/zpt9r/N/RacG/J6PwX8BriL7vduXAP+IzP/p43bd87azwGPNacrW7dtl6t1YZA0eq06lJDUDoZBUmEYJBWGQVJhGCQVhkFSYRgkFYZBUvH/dUJ6/Qp1P48AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# policy_net_1 = PolicyNetwork(in_dim, out_dim)\n",
    "# policy_net_1.load_state_dict(torch.load(\"./model/DDPG.pth\"))\n",
    "# policy_net_1.eval()\n",
    "\n",
    "import torch\n",
    "import gym\n",
    "from IPython import display\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "env = gym.make(\"Pendulum-v0\")\n",
    "state = env.reset()\n",
    "policy_net = policy_net.cpu()  # move the trained actor to CPU once, before the loop\n",
    "img = plt.imshow(env.render(mode='rgb_array')) # only call this once\n",
    "for _ in range(1000):\n",
    "    img.set_data(env.render(mode='rgb_array')) # just update the data\n",
    "    display.display(plt.gcf())\n",
    "    display.clear_output(wait=True)\n",
    "    action = policy_net(torch.FloatTensor(state)).detach().numpy()\n",
    "    # action = env.action_space.sample()\n",
    "    next_state, _, done, _ = env.step(action)\n",
    "    if done:\n",
    "        state = env.reset()\n",
    "    else:\n",
    "        state = next_state"
   ]
  },
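  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The target networks in the algorithm box above are refreshed with a soft (Polyak) update, $\\theta' \\leftarrow \\tau\\theta + (1-\\tau)\\theta'$, rather than a hard copy. A minimal sketch of that step, assuming the networks are plain `torch.nn.Module`s; the names `soft_update`, `net`, and `target_net` are illustrative, not taken from the training code above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def soft_update(net, target_net, tau=1e-2):\n",
    "    # theta_target <- tau * theta + (1 - tau) * theta_target\n",
    "    with torch.no_grad():\n",
    "        for param, target_param in zip(net.parameters(), target_net.parameters()):\n",
    "            target_param.mul_(1.0 - tau).add_(tau * param)\n",
    "\n",
    "# Called once per training step, after the critic and actor updates, e.g.\n",
    "# soft_update(value_net, target_value_net); soft_update(policy_net, target_policy_net)"
   ]
  },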
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Training ran for roughly 12 hours and still had not finished; I had set the number of episodes far too high. The TensorBoard curves show it had already converged after about an hour.<br>\n",
    "I stopped it partway through, so the final plot was never drawn, but fortunately TensorBoard kept a record.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* #### Curves from the first training run:\n",
    "<img src=\"../assets/ddpg_first_train.png\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* #### Training curves without tanh:<br>\n",
    "**Reward :**\n",
    "<img src=\"../assets/Reward.png\"><br>\n",
    "**Value_Losses :**\n",
    "<img src=\"../assets/Value_Losses.png\"><br>\n",
    "**Policy_Losses :**\n",
    "<img src=\"../assets/Policy_Losses.png\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* #### Training curves with tanh:<br>\n",
    "**Reward :**\n",
    "<img src=\"../assets/Reward_with_tanh.png\"><br>\n",
    "**Value_Losses :**\n",
    "<img src=\"../assets/Value_Losses_with_tanh.png\"><br>\n",
    "**Policy_Losses :**\n",
    "<img src=\"../assets/Policy_Losses_with_tanh.png\">"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "oldHeight": 691.0904539999999,
   "position": {
    "height": "712.525px",
    "left": "1171.09px",
    "right": "20px",
    "top": "194px",
    "width": "603.05px"
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "varInspector_section_display": "block",
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
